{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/lib/python3.5/dist-packages/h5py/__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.\n",
      "  from ._conv import register_converters as _register_converters\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "import collections"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def build_dataset(words, n_words):\n",
    "    count = [['GO', 0], ['PAD', 1], ['EOS', 2], ['UNK', 3]]\n",
    "    count.extend(collections.Counter(words).most_common(n_words - 1))\n",
    "    dictionary = dict()\n",
    "    for word, _ in count:\n",
    "        dictionary[word] = len(dictionary)\n",
    "    data = list()\n",
    "    unk_count = 0\n",
    "    for word in words:\n",
    "        index = dictionary.get(word, 0)\n",
    "        if index == 0:\n",
    "            unk_count += 1\n",
    "        data.append(index)\n",
    "    count[0][1] = unk_count\n",
    "    reversed_dictionary = dict(zip(dictionary.values(), dictionary.keys()))\n",
    "    return data, count, dictionary, reversed_dictionary"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "len from: 265, len to: 265\n"
     ]
    }
   ],
   "source": [
    "with open('from.txt', 'r') as fopen:\n",
    "    text_from = fopen.read().lower().split('\\n')\n",
    "with open('to.txt', 'r') as fopen:\n",
    "    text_to = fopen.read().lower().split('\\n')\n",
    "print('len from: %d, len to: %d'%(len(text_from), len(text_to)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "vocab from size: 331\n",
      "Most common words [('you', 73), ('is', 67), ('what', 63), ('a', 49), ('the', 40), ('do', 36)]\n",
      "Sample data [129, 61, 152, 61, 238, 61, 324, 61, 258, 12] ['hi', 'good', 'morning', 'good', 'afternoon', 'good', 'evening', 'good', 'night', 'how']\n"
     ]
    }
   ],
   "source": [
    "concat_from = ' '.join(text_from).split()\n",
    "vocabulary_size_from = len(list(set(concat_from)))\n",
    "data_from, count_from, dictionary_from, rev_dictionary_from = build_dataset(concat_from, vocabulary_size_from)\n",
    "print('vocab from size: %d'%(vocabulary_size_from))\n",
    "print('Most common words', count_from[4:10])\n",
    "print('Sample data', data_from[:10], [rev_dictionary_from[i] for i in data_from[:10]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "vocab to size: 504\n",
      "Most common words [('i', 127), ('you', 55), ('a', 47), ('to', 44), ('the', 40), ('it', 38)]\n",
      "Sample data [205, 120, 36, 224, 36, 440, 36, 270, 36, 461] ['hi', 'there', 'good', 'morning', 'good', 'afternoon', 'good', 'evening', 'good', 'night']\n"
     ]
    }
   ],
   "source": [
    "concat_to = ' '.join(text_to).split()\n",
    "vocabulary_size_to = len(list(set(concat_to)))\n",
    "data_to, count_to, dictionary_to, rev_dictionary_to = build_dataset(concat_to, vocabulary_size_to)\n",
    "print('vocab to size: %d'%(vocabulary_size_to))\n",
    "print('Most common words', count_to[4:10])\n",
    "print('Sample data', data_to[:10], [rev_dictionary_to[i] for i in data_to[:10]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "GO = dictionary_from['GO']\n",
    "PAD = dictionary_from['PAD']\n",
    "EOS = dictionary_from['EOS']\n",
    "UNK = dictionary_from['UNK']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Chatbot:\n",
    "    def __init__(self, size_layer, num_layers, embedded_size,\n",
    "                 from_dict_size, to_dict_size, learning_rate, batch_size):\n",
    "        \n",
    "        def cells(reuse=False):\n",
    "            return tf.nn.rnn_cell.BasicRNNCell(size_layer,reuse=reuse)\n",
    "        \n",
    "        self.X = tf.placeholder(tf.int32, [None, None])\n",
    "        self.Y = tf.placeholder(tf.int32, [None, None])\n",
    "        self.X_seq_len = tf.placeholder(tf.int32, [None])\n",
    "        self.Y_seq_len = tf.placeholder(tf.int32, [None])\n",
    "        \n",
    "        encoder_embeddings = tf.Variable(tf.random_uniform([from_dict_size, embedded_size], -1, 1))\n",
    "        decoder_embeddings = tf.Variable(tf.random_uniform([to_dict_size, embedded_size], -1, 1))\n",
    "        encoder_embedded = tf.nn.embedding_lookup(encoder_embeddings, self.X)\n",
    "        main = tf.strided_slice(self.X, [0, 0], [batch_size, -1], [1, 1])\n",
    "        decoder_input = tf.concat([tf.fill([batch_size, 1], GO), main], 1)\n",
    "        decoder_embedded = tf.nn.embedding_lookup(encoder_embeddings, decoder_input)\n",
    "        \n",
    "        attention_mechanism = tf.contrib.seq2seq.LuongAttention(num_units = size_layer, \n",
    "                                                                    memory = encoder_embedded)\n",
    "        rnn_cells = tf.contrib.seq2seq.AttentionWrapper(cell = tf.nn.rnn_cell.MultiRNNCell([cells() for _ in range(num_layers)]), \n",
    "                                                        attention_mechanism = attention_mechanism,\n",
    "                                                        attention_layer_size = size_layer)\n",
    "        _, last_state = tf.nn.dynamic_rnn(rnn_cells, encoder_embedded,\n",
    "                                          dtype = tf.float32)\n",
    "        last_state = tuple(last_state[0][-1] for _ in range(num_layers))\n",
    "        with tf.variable_scope(\"decoder\"):\n",
    "            rnn_cells_dec = tf.nn.rnn_cell.MultiRNNCell([cells() for _ in range(num_layers)])\n",
    "            outputs, _ = tf.nn.dynamic_rnn(rnn_cells_dec, decoder_embedded, \n",
    "                                           initial_state = last_state,\n",
    "                                           dtype = tf.float32)\n",
    "        self.logits = tf.layers.dense(outputs,to_dict_size)\n",
    "        masks = tf.sequence_mask(self.Y_seq_len, tf.reduce_max(self.Y_seq_len), dtype=tf.float32)\n",
    "        self.cost = tf.contrib.seq2seq.sequence_loss(logits = self.logits,\n",
    "                                                     targets = self.Y,\n",
    "                                                     weights = masks)\n",
    "        self.optimizer = tf.train.AdamOptimizer(learning_rate = learning_rate).minimize(self.cost)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "size_layer = 128\n",
    "num_layers = 2\n",
    "embedded_size = 128\n",
    "learning_rate = 0.001\n",
    "batch_size = 32\n",
    "epoch = 50"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "tf.reset_default_graph()\n",
    "sess = tf.InteractiveSession()\n",
    "model = Chatbot(size_layer, num_layers, embedded_size, vocabulary_size_from + 4, \n",
    "                vocabulary_size_to + 4, learning_rate, batch_size)\n",
    "sess.run(tf.global_variables_initializer())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "def str_idx(corpus, dic):\n",
    "    X = []\n",
    "    for i in corpus:\n",
    "        ints = []\n",
    "        for k in i.split():\n",
    "            try:\n",
    "                ints.append(dic[k])\n",
    "            except Exception as e:\n",
    "                print(e)\n",
    "                ints.append(2)\n",
    "        X.append(ints)\n",
    "    return X"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "'nothing'\n",
      "'mention'\n"
     ]
    }
   ],
   "source": [
    "X = str_idx(text_from, dictionary_from)\n",
    "Y = str_idx(text_to, dictionary_to)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "def pad_sentence_batch(sentence_batch, pad_int):\n",
    "    padded_seqs = []\n",
    "    seq_lens = []\n",
    "    max_sentence_len = 50\n",
    "    for sentence in sentence_batch:\n",
    "        padded_seqs.append(sentence + [pad_int] * (max_sentence_len - len(sentence)))\n",
    "        seq_lens.append(50)\n",
    "    return padded_seqs, seq_lens\n",
    "\n",
    "def check_accuracy(logits, Y):\n",
    "    acc = 0\n",
    "    for i in range(logits.shape[0]):\n",
    "        internal_acc = 0\n",
    "        for k in range(len(Y[i])):\n",
    "            if Y[i][k] == logits[i][k]:\n",
    "                internal_acc += 1\n",
    "        acc += (internal_acc / len(Y[i]))\n",
    "    return acc / logits.shape[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 1, avg loss: 4.023362, avg accuracy: 0.742500\n",
      "epoch: 2, avg loss: 1.370480, avg accuracy: 0.867813\n",
      "epoch: 3, avg loss: 1.029058, avg accuracy: 0.868828\n",
      "epoch: 4, avg loss: 0.974061, avg accuracy: 0.868906\n",
      "epoch: 5, avg loss: 0.949733, avg accuracy: 0.869141\n",
      "epoch: 6, avg loss: 0.923775, avg accuracy: 0.869297\n",
      "epoch: 7, avg loss: 0.907751, avg accuracy: 0.869766\n",
      "epoch: 8, avg loss: 0.894993, avg accuracy: 0.869922\n",
      "epoch: 9, avg loss: 0.883858, avg accuracy: 0.870313\n",
      "epoch: 10, avg loss: 0.873926, avg accuracy: 0.871094\n",
      "epoch: 11, avg loss: 0.865184, avg accuracy: 0.870938\n",
      "epoch: 12, avg loss: 0.856581, avg accuracy: 0.871328\n",
      "epoch: 13, avg loss: 0.849098, avg accuracy: 0.871016\n",
      "epoch: 14, avg loss: 0.840666, avg accuracy: 0.871172\n",
      "epoch: 15, avg loss: 0.833286, avg accuracy: 0.871016\n",
      "epoch: 16, avg loss: 0.824797, avg accuracy: 0.871172\n",
      "epoch: 17, avg loss: 0.815313, avg accuracy: 0.871406\n",
      "epoch: 18, avg loss: 0.804944, avg accuracy: 0.871875\n",
      "epoch: 19, avg loss: 0.793187, avg accuracy: 0.871953\n",
      "epoch: 20, avg loss: 0.781281, avg accuracy: 0.871797\n",
      "epoch: 21, avg loss: 0.773517, avg accuracy: 0.872031\n",
      "epoch: 22, avg loss: 0.768158, avg accuracy: 0.872109\n",
      "epoch: 23, avg loss: 0.764721, avg accuracy: 0.872188\n",
      "epoch: 24, avg loss: 0.765491, avg accuracy: 0.873516\n",
      "epoch: 25, avg loss: 0.758725, avg accuracy: 0.871953\n",
      "epoch: 26, avg loss: 0.741139, avg accuracy: 0.872656\n",
      "epoch: 27, avg loss: 0.729767, avg accuracy: 0.872891\n",
      "epoch: 28, avg loss: 0.720276, avg accuracy: 0.873281\n",
      "epoch: 29, avg loss: 0.711709, avg accuracy: 0.873672\n",
      "epoch: 30, avg loss: 0.703955, avg accuracy: 0.873828\n",
      "epoch: 31, avg loss: 0.696774, avg accuracy: 0.874219\n",
      "epoch: 32, avg loss: 0.689044, avg accuracy: 0.874766\n",
      "epoch: 33, avg loss: 0.684563, avg accuracy: 0.874766\n",
      "epoch: 34, avg loss: 0.733471, avg accuracy: 0.874297\n",
      "epoch: 35, avg loss: 0.788000, avg accuracy: 0.871094\n",
      "epoch: 36, avg loss: 0.786791, avg accuracy: 0.872813\n",
      "epoch: 37, avg loss: 0.740085, avg accuracy: 0.872656\n",
      "epoch: 38, avg loss: 0.716111, avg accuracy: 0.875781\n",
      "epoch: 39, avg loss: 0.697435, avg accuracy: 0.876016\n",
      "epoch: 40, avg loss: 0.684894, avg accuracy: 0.876328\n",
      "epoch: 41, avg loss: 0.674476, avg accuracy: 0.875469\n",
      "epoch: 42, avg loss: 0.665863, avg accuracy: 0.877813\n",
      "epoch: 43, avg loss: 0.658186, avg accuracy: 0.877813\n",
      "epoch: 44, avg loss: 0.651171, avg accuracy: 0.878203\n",
      "epoch: 45, avg loss: 0.644903, avg accuracy: 0.878828\n",
      "epoch: 46, avg loss: 0.638759, avg accuracy: 0.878672\n",
      "epoch: 47, avg loss: 0.632353, avg accuracy: 0.879297\n",
      "epoch: 48, avg loss: 0.626532, avg accuracy: 0.880156\n",
      "epoch: 49, avg loss: 0.620819, avg accuracy: 0.881016\n",
      "epoch: 50, avg loss: 0.615854, avg accuracy: 0.881641\n"
     ]
    }
   ],
   "source": [
    "for i in range(epoch):\n",
    "    total_loss, total_accuracy = 0, 0\n",
    "    for k in range(0, (len(text_from) // batch_size) * batch_size, batch_size):\n",
    "        batch_x, seq_x = pad_sentence_batch(X[k: k+batch_size], PAD)\n",
    "        batch_y, seq_y = pad_sentence_batch(Y[k: k+batch_size], PAD)\n",
    "        predicted, loss, _ = sess.run([tf.argmax(model.logits,2), model.cost, model.optimizer], \n",
    "                                      feed_dict={model.X:batch_x,\n",
    "                                                model.Y:batch_y,\n",
    "                                                model.X_seq_len:seq_x,\n",
    "                                                model.Y_seq_len:seq_y})\n",
    "        total_loss += loss\n",
    "        total_accuracy += check_accuracy(predicted,batch_y)\n",
    "    total_loss /= (len(text_from) // batch_size)\n",
    "    total_accuracy /= (len(text_from) // batch_size)\n",
    "    print('epoch: %d, avg loss: %f, avg accuracy: %f'%(i+1, total_loss, total_accuracy))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
